from transformers import AutoTokenizer, AutoModelForCausalLM
from argparse import ArgumentParser
import torch
from tqdm import tqdm
import os
from torch.utils.data import DataLoader, Dataset
from tools import DynamicDataset
import jsonlines
from peft import PeftModel
import json


def generate_and_tokenize_prompt(usrs):
    """Wrap each user message in a chat prompt and tokenize the whole batch.

    Relies on the module-level ``opt`` (for ``opt.base_model``) and
    ``tokenizer``.

    Args:
        usrs: iterable of user-message strings, one per sample.

    Returns:
        A tuple ``(input_ids, attention_mask)`` of tensors, both moved to the
        CUDA device with ``.cuda()``.
    """
    if "Orca" in opt.base_model:
        # Orca prompts are assembled by hand in ChatML form (fixed system
        # prompt + user turn + open assistant turn) instead of going through
        # tokenizer.apply_chat_template.
        system = "You are Orca, an AI language model created by Microsoft. You are a cautious assistant. You carefully follow instructions. You are helpful and harmless and you follow ethical guidelines and promote positive behavior."
        messages = [
            f"<|im_start|>system\n{system}<|im_end|>\n<|im_start|>user\n{usr}<|im_end|>\n<|im_start|>assistant"
            for usr in usrs
        ]
    else:
        # Every other model uses the chat template shipped with its tokenizer,
        # applied to a single-turn conversation.
        messages = [
            tokenizer.apply_chat_template(
                [{"role": "user", "content": usr}],
                add_generation_prompt=True,
                tokenize=False,
            )
            for usr in usrs
        ]

    # padding=True pads the batch so it can be returned as one tensor.
    encoded = tokenizer(messages, padding=True, return_tensors="pt")
    ids, mask = encoded.input_ids, encoded.attention_mask
    return ids.cuda(), mask.cuda()


def generate_user_prompt(opt, data):
    """Build one user-prompt string per record.

    Args:
        opt: options object; ``opt.wo_instruction`` and ``opt.mode`` select
            the prompt layout.
        data: iterable of records, each with a question under ``'Q'`` and a
            (possibly empty) sequence of option strings under ``'O'``.

    Returns:
        List of prompt strings, in the same order as ``data``.

    Raises:
        AssertionError: if ``opt.mode`` is not supported for the selected
            ``opt.wo_instruction`` setting (checked per record).
    """
    prompts = []
    for record in data:
        question = record['Q']
        options = record['O']
        joined = " ".join(options)
        has_options = len(options) != 0

        if opt.wo_instruction:
            # No task instruction: only the question (and options) are sent.
            if opt.mode == "direct":
                # Questions that already carry a "Passage" get no prefix.
                prefix = "" if 'Passage' in question else "Question: "
                text = f"{prefix}{question}"
                if has_options:
                    text += f"\nOptions: {joined}"
            else:
                assert opt.mode == "pure"
                text = f"{question}\n{joined}" if has_options else f"{question}"
        elif opt.mode == "cues":
            # Ask for reasoning cues first, then the final answer.
            if has_options:
                text = f"Your are given a question togehter with some options, you SHOULD and MUST answer this question by choosing an option. You can only choose only one option. Please give your reasoning cues first, then give the final answer. Please follow the format like \"Reasoning cues: ___. Therefore, the answer is ___.\"\nQuestion: {question}\nOptions: {joined}"
            else:
                text = f"Your are given a question. Please answer this question. Please give your reasoning cues first, then give the final answer. Please follow the format like \"Reasoning cues: ___. Therefore, the answer is ___.\"\nQuestion: {question}"
        else:
            assert opt.mode == "direct"
            # The 'direct' template comes from the module-level `instructions`.
            if has_options:
                text = instructions['direct'].format(question, joined)
            else:
                # Without options, also strip the template's "Options: " label.
                text = instructions['direct'].format(question, "").replace("Options: ", "")

        prompts.append(text)
    return prompts


def _str2bool(value):
    """Convert a command-line string such as ``True``/``false``/``1``/``no`` to a bool.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is True
    because every non-empty string is truthy, so ``--lora False`` would
    silently enable the option.

    Raises:
        ValueError: for unrecognised spellings; argparse reports this as a
            normal usage error.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ("true", "t", "yes", "y", "1"):
        return True
    # The empty string stays falsy, as it was with type=bool.
    if text in ("false", "f", "no", "n", "0", ""):
        return False
    raise ValueError(f"expected a boolean value, got {value!r}")


# Command-line options for the batch inference run.
parser = ArgumentParser()
parser.add_argument('--input_dir', type=str, default='./data/')
parser.add_argument('--base_model', type=str, default='')
parser.add_argument('--lora', type=_str2bool, default=False)
parser.add_argument('--lora_dir', type=str, default='./lora/gpt-4')
parser.add_argument('--output_dir', type=str, default='./output/commonsense')
parser.add_argument('--instruction', type=str, default='./data/instruction.json')
parser.add_argument('--wo_instruction', type=_str2bool, default=False)
parser.add_argument('--template', type=str, default='./data/chat_template.json')
parser.add_argument('--batch_size', type=int, default=128)
parser.add_argument('--temperature', type=float, default=0.0000001)
parser.add_argument('--max_gen_len', type=int, default=512)
# 'pure' is handled by generate_user_prompt (with --wo_instruction) and was
# previously impossible to select from the command line.
parser.add_argument('--mode', type=str, default='direct', choices=['direct', 'rag', 'cues', 'pure'])
parser.add_argument('--cues', type=_str2bool, default=False)
opt = parser.parse_args()
print(opt)

# Create the output directory; exist_ok avoids the check-then-create race.
os.makedirs(opt.output_dir, exist_ok=True)

# Gemma checkpoints are loaded with the eager attention implementation;
# every other model uses the library default.
load_kwargs = {"load_in_8bit": False, "device_map": "auto"}
if "gemma" in opt.base_model:
    load_kwargs["attn_implementation"] = 'eager'
model = AutoModelForCausalLM.from_pretrained(opt.base_model, **load_kwargs)

# Optionally wrap the base model with LoRA adapter weights.
if opt.lora:
    model = PeftModel.from_pretrained(model, opt.lora_dir)

try:
    tokenizer = AutoTokenizer.from_pretrained(opt.base_model, padding_side="left")
except Exception:
    # Fallback: load the tokenizer from the path recorded in the checkpoint's
    # config.json. `except Exception` (rather than a bare except) lets
    # KeyboardInterrupt / SystemExit propagate.
    with open(f"{opt.base_model}/config.json", encoding="utf-8") as config_file:
        path = json.load(config_file)['_name_or_path']
    tokenizer = AutoTokenizer.from_pretrained(path, padding_side="left")

tokenizer.padding_side = 'left'
print(f"[Padding Side]: {tokenizer.padding_side}")
tokenizer.pad_token_id = 0

# Chat templates and instruction templates are read once at start-up; the
# files are closed as soon as they are parsed.
with open(opt.template, "r", encoding="utf-8") as template_file:
    chat_templates = json.load(template_file)
with open(opt.instruction, "r", encoding="utf-8") as instruction_file:
    instructions = json.load(instruction_file)

# Files that run at half the configured batch size (presumably because their
# prompts are long -- confirm against the datasets).
_HALF_BATCH_FILES = frozenset({
    "high_school_european_history_test.jsonl",
    "professional_law_test.jsonl",
    "high_school_us_history_test.jsonl",
    "high_school_world_history_test.jsonl",
    "international_law_test.jsonl",
    "world_religions_test.jsonl",
    "prehistory_test.jsonl",
    "college_medicine_test.jsonl",
})


def process_file(opt, data_name):
    """Generate a model output for every record of one JSONL file.

    Each record is written to
    ``<opt.output_dir>/<file stem>_<opt.mode>.jsonl`` with the decoded
    generation stored under the key ``opt.mode``. If that output file already
    exists, the run resumes after the records it already contains.

    Relies on the module-level ``model`` and ``tokenizer``.

    Args:
        opt: parsed command-line options.
        data_name: file name (not path) of a JSONL file inside
            ``opt.input_dir``.
    """
    print(f"[Processing]\t{data_name}")
    with jsonlines.open(os.path.join(opt.input_dir, data_name)) as reader:
        data = list(reader)
    users = generate_user_prompt(opt, data)
    # DynamicDataset takes a system prompt per sample; this placeholder is
    # discarded when the batches are consumed below.
    systems = ["..."] * len(data)

    if data_name in _HALF_BATCH_FILES or "race" in opt.input_dir:
        batch_size = opt.batch_size // 2
    elif "agieval" in opt.input_dir:
        batch_size = opt.batch_size // 3
    else:
        batch_size = opt.batch_size
    # Floor division can yield 0 for small --batch_size values, which
    # DataLoader rejects.
    batch_size = max(1, batch_size)

    output_path = os.path.join(opt.output_dir, f"{data_name.split('.')[0]}_{opt.mode}.jsonl")
    # Number of records already written by a previous (interrupted) run.
    num = 0
    if os.path.exists(output_path):
        with jsonlines.open(output_path, 'r') as finished:
            num = sum(1 for _ in finished)

    loader = DataLoader(
        DynamicDataset(systems[num:], users[num:]),
        batch_size=batch_size,
        shuffle=False,
    )

    model.eval()
    # Start at the first unprocessed record. Starting at 0 after a resume
    # would attach the new generations to the wrong records.
    index = num

    # Append mode creates the file when it does not exist yet; the context
    # manager guarantees the writer is closed even if generation fails.
    with jsonlines.open(output_path, mode='a') as fo:
        for _, usrs in loader:
            input_ids, attention_mask = generate_and_tokenize_prompt(usrs)
            with torch.no_grad():
                outputs = model.generate(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    temperature=opt.temperature,
                    max_new_tokens=opt.max_gen_len,
                    do_sample=True,
                    return_dict_in_generate=True,
                    num_return_sequences=1,
                    pad_token_id=tokenizer.eos_token_id)
            sentences = [
                tokenizer.decode(s, clean_up_tokenization_spaces=True)
                for s in outputs.sequences
            ]
            for sentence in sentences:
                data[index][opt.mode] = sentence
                fo.write(data[index])
                index += 1


# Run inference on every entry of the input directory whose name ends in "jsonl".
files = list(filter(lambda name: name.endswith("jsonl"), os.listdir(opt.input_dir)))
for f in files:
    process_file(opt, f)

